import os
import sys
import json
import copy
import time
import random
import base64

sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

import cv2
import numpy as np
from ai_hub import inferServer
from tools.infer import utility, predict_cls, predict_det, predict_rec
from ppocr.utils.logging import get_logger


# Module-level logger obtained from ppocr's logging utility (used by tccServer).
logger = get_logger()


class TextSystem(object):
    """OCR pipeline: text detection, optional angle classification, recognition.

    Calling an instance on an image returns the detected boxes (in reading
    order) together with their recognition results, keeping only results
    whose score is at least ``drop_score``.
    """

    def __init__(self, args):
        # ``args`` is presumably the namespace from ``utility.parse_args()``
        # (as built in the __main__ block); only the fields read here and by
        # the ppocr predictors are required.
        self.text_detector = predict_det.TextDetector(args)
        self.text_recognizer = predict_rec.TextRecognizer(args)
        self.use_angle_cls = args.use_angle_cls
        self.drop_score = args.drop_score
        if self.use_angle_cls:
            self.text_classifier = predict_cls.TextClassifier(args)

    def get_rotate_crop_image(self, img, points):
        """Rectify the quadrilateral ``points`` of ``img`` into an upright crop.

        ``points[0..3]`` are mapped onto the top-left, top-right, bottom-right
        and bottom-left corners of an axis-aligned output rectangle. Its width
        and height are the longer of the two opposing edge lengths. A crop
        that ends up at least 1.5x taller than wide is rotated 90 degrees
        counter-clockwise (``np.rot90``), presumably to handle vertical text.

        Args:
            img: source image as a numpy array.
            points: (4, 2) array of corner coordinates; must be float32 for
                ``cv2.getPerspectiveTransform``.

        Returns:
            The warped (and possibly rotated) crop as a numpy array.
        """
        img_crop_width = int(
            max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3])))
        img_crop_height = int(
            max(
                np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2])))
        pts_std = np.float32([[0, 0], [img_crop_width, 0],
                              [img_crop_width, img_crop_height],
                              [0, img_crop_height]])
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M, (img_crop_width, img_crop_height),
            # Replicate edge pixels so the warp never introduces black borders.
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC)
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img

    def sort_boxes(self, dt_boxes, y_tolerance=10):
        """Sort text boxes top-to-bottom, then left-to-right.

        Boxes are first ordered by the (y, x) of their first corner. A single
        pass then swaps neighbouring boxes whose first-corner y differs by
        less than ``y_tolerance`` when the later one lies further left, so
        boxes on roughly the same line read left-to-right. Being a single
        pass, it does not fully reorder three or more boxes on one line.

        Args:
            dt_boxes: sequence (e.g. an (N, 4, 2) array or a list) of N boxes,
                each made of four (x, y) corners.
            y_tolerance: maximum first-corner y difference for two boxes to be
                treated as being on the same line.

        Returns:
            list of the N boxes in reading order.
        """
        boxes = sorted(dt_boxes, key=lambda box: (box[0][1], box[0][0]))
        for i in range(len(boxes) - 1):
            cur, nxt = boxes[i], boxes[i + 1]
            if abs(nxt[0][1] - cur[0][1]) < y_tolerance and nxt[0][0] < cur[0][0]:
                boxes[i], boxes[i + 1] = nxt, cur
        return boxes

    def __call__(self, img):
        """Run detection, optional angle classification and recognition.

        Args:
            img: input image as a numpy array.

        Returns:
            ``(boxes, rec_results)``. ``boxes`` lists the box arrays in reading
            order and ``rec_results`` holds the matching recognizer results
            (``(text, score)`` pairs). Both are filtered to
            ``score >= drop_score``. Returns ``(None, None)`` when the
            detector returns ``None`` boxes.
        """
        # Crop from a copy, in case the detector modifies ``img``.
        ori_im = img.copy()
        dt_boxes, _ = self.text_detector(img)
        if dt_boxes is None:
            return None, None

        dt_boxes = self.sort_boxes(dt_boxes)
        # Deep-copy each box so cropping can never alter the boxes we return.
        img_crop_list = [
            self.get_rotate_crop_image(ori_im, copy.deepcopy(box))
            for box in dt_boxes
        ]
        if self.use_angle_cls:
            img_crop_list, _, _ = self.text_classifier(img_crop_list)

        rec_res, _ = self.text_recognizer(img_crop_list)
        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            _, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        return filter_boxes, filter_rec_res


class tccServer(inferServer):
    """ai_hub inference server exposing a ``TextSystem`` OCR pipeline.

    The request body is JSON with a base64-encoded image under ``"img"`` and
    an identifier under ``"index"``; the first element of each is used. The
    response is a JSON object mapping that identifier to its detected
    polygons and transcriptions.
    """

    def __init__(self, text_sys: TextSystem):
        super().__init__(text_sys)
        logger.info('init tcc server')
        self.model = text_sys

    def pre_process(self, request):
        """Decode an incoming request into ``(image, name)``.

        Args:
            request: object exposing ``get_data()`` that returns the raw
                body bytes (presumably a Flask request — confirm in ai_hub).

        Returns:
            ``(img, name)``: the BGR image decoded by OpenCV and the
            request's first ``"index"`` entry.

        Raises:
            ValueError: if the ``"img"`` payload is not a decodable image.
        """
        json_data = json.loads(request.get_data().decode('utf-8'))
        image_data = base64.b64decode(json_data.get("img")[0])
        img = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imdecode returns None for undecodable data. Fail here with a
            # clear message instead of an AttributeError deep in the pipeline.
            raise ValueError('failed to decode request image')
        name = json_data.get('index')[0]
        return img, name

    def predict(self, data):
        """Run OCR on one decoded image and serialise the result as JSON.

        Args:
            data: ``(img, key)`` tuple as returned by ``pre_process``.

        Returns:
            JSON string ``{key: {"pointsList": [...], "transcriptionsList": [...]}}``.
            Each points entry is a box flattened row-major into floats
            (presumably ``[x0, y0, x1, y1, ...]``). Both lists are empty when
            the model reports no boxes.
        """
        img, key = data
        dt_boxes, rec_res = self.model(img)
        if dt_boxes is None:
            # TextSystem returns (None, None) when the detector yields no
            # boxes; zip() over None would raise TypeError, so report no text.
            dt_boxes, rec_res = [], []

        points_list, transcriptions_list = [], []
        for box, (text, _score) in zip(dt_boxes, rec_res):
            points_list.append([float(x) for x in box.reshape(-1)])
            transcriptions_list.append(text)

        results = {
            key: {
                'pointsList': points_list,
                'transcriptionsList': transcriptions_list,
            }
        }
        return json.dumps(results, ensure_ascii=False)


if __name__ == '__main__':
    # Build the OCR pipeline from command-line arguments and serve it.
    args = utility.parse_args()

    text_sys = TextSystem(args)
    myserver = tccServer(text_sys)
    # NOTE: ``debuge`` is spelled exactly as in the original call; presumably
    # it is ai_hub's own keyword name — check before "fixing" the spelling.
    myserver.run(debuge=True)
    # For deployment, listen on all interfaces with debug mode off:
    # myserver.run(ip='0.0.0.0', debuge=False)


